{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prompt2Prompt Pipeline\n",
    "\n",
    "Prompt2Prompt allows the following edits:\n",
    "\n",
    "1. ReplaceEdit (change words in prompt)\n",
    "2. ReplaceEdit with local blend (change words in the prompt while keeping the parts of the image unrelated to the change unchanged)\n",
    "3. RefineEdit (add words to prompt)\n",
    "4. RefineEdit with local blend (add words to the prompt while keeping the parts of the image unrelated to the change unchanged)\n",
    "5. ReweightEdit (modulate importance of words)\n",
    "\n",
    "The code cell below runs a complete ReplaceEdit example. Abbreviated examples for the other edits:\n",
    "\n",
    "ReplaceEdit with local blend\n",
    "\n",
    "```python\n",
    "prompts = [\"A turtle playing with a ball\",\n",
    "           \"A monkey playing with a ball\"]\n",
    "\n",
    "cross_attention_kwargs = {\n",
    "    \"edit_type\": \"replace\",\n",
    "    \"cross_replace_steps\": 0.4,\n",
    "    \"self_replace_steps\": 0.4,\n",
    "    \"local_blend_words\": [\"turtle\", \"monkey\"]\n",
    "}\n",
    "```\n",
    "\n",
    "RefineEdit\n",
    "\n",
    "```python\n",
    "prompts = [\"A turtle\",\n",
    "           \"A turtle in a forest\"]\n",
    "\n",
    "cross_attention_kwargs = {\n",
    "    \"edit_type\": \"refine\",\n",
    "    \"cross_replace_steps\": 0.4,\n",
    "    \"self_replace_steps\": 0.4,\n",
    "}\n",
    "```\n",
    "\n",
    "RefineEdit with local blend\n",
    "\n",
    "```python\n",
    "prompts = [\"A turtle\",\n",
    "           \"A turtle in a forest\"]\n",
    "\n",
    "cross_attention_kwargs = {\n",
    "    \"edit_type\": \"refine\",\n",
    "    \"cross_replace_steps\": 0.4,\n",
    "    \"self_replace_steps\": 0.4,\n",
    "    \"local_blend_words\": [\"in\", \"a\", \"forest\"]\n",
    "}\n",
    "```\n",
    "\n",
    "ReweightEdit\n",
    "\n",
    "```python\n",
    "prompts = [\"A smiling turtle\"] * 2\n",
    "\n",
    "cross_attention_kwargs = {\n",
    "    \"edit_type\": \"reweight\",\n",
    "    \"cross_replace_steps\": 0.4,\n",
    "    \"self_replace_steps\": 0.4,\n",
    "    \"equalizer_words\": [\"smiling\"],\n",
    "    \"equalizer_strengths\": [5]\n",
    "}\n",
    "```\n",
    "\n",
    "Side note: See this [GitHub gist](https://gist.github.com/UmerHA/b65bb5fb9626c9c73f3ade2869e36164) if you want to visualize the attention maps. This script was contributed by [Umer Adil](https://github.com/UmerHA) and the notebook by [Parag Ekbote](https://github.com/ParagEkbote)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "cannot get type annotation for Parameter vae of <class 'diffusers_modules.git.pipeline_prompt2prompt.Prompt2PromptPipeline'>.\n",
      "cannot get type annotation for Parameter text_encoder of <class 'diffusers_modules.git.pipeline_prompt2prompt.Prompt2PromptPipeline'>.\n",
      "cannot get type annotation for Parameter tokenizer of <class 'diffusers_modules.git.pipeline_prompt2prompt.Prompt2PromptPipeline'>.\n",
      "cannot get type annotation for Parameter unet of <class 'diffusers_modules.git.pipeline_prompt2prompt.Prompt2PromptPipeline'>.\n",
      "cannot get type annotation for Parameter scheduler of <class 'diffusers_modules.git.pipeline_prompt2prompt.Prompt2PromptPipeline'>.\n",
      "cannot get type annotation for Parameter safety_checker of <class 'diffusers_modules.git.pipeline_prompt2prompt.Prompt2PromptPipeline'>.\n",
      "cannot get type annotation for Parameter feature_extractor of <class 'diffusers_modules.git.pipeline_prompt2prompt.Prompt2PromptPipeline'>.\n",
      "cannot get type annotation for Parameter image_encoder of <class 'diffusers_modules.git.pipeline_prompt2prompt.Prompt2PromptPipeline'>.\n",
      "cannot get type annotation for Parameter requires_safety_checker of <class 'diffusers_modules.git.pipeline_prompt2prompt.Prompt2PromptPipeline'>.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3e47bb32a31d4d6fb2e0fc2a197ac074",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "431ad3ea3ee04839bba22a7c67a883c4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "from diffusers import DiffusionPipeline\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "# Load Stable Diffusion v1.4 with the community Prompt2Prompt pipeline\n",
    "pipe = DiffusionPipeline.from_pretrained(\n",
    "    \"CompVis/stable-diffusion-v1-4\",\n",
    "    custom_pipeline=\"pipeline_prompt2prompt\"\n",
    ").to(\"cuda\")\n",
    "\n",
    "# Source prompt first, edited prompt second (\"turtle\" is replaced by \"monkey\")\n",
    "prompts = [\n",
    "    \"A turtle playing with a ball\",\n",
    "    \"A monkey playing with a ball\"\n",
    "]\n",
    "\n",
    "# ReplaceEdit settings, handed to the pipeline through cross_attention_kwargs\n",
    "cross_attention_kwargs = {\n",
    "    \"edit_type\": \"replace\",\n",
    "    \"cross_replace_steps\": 0.4,\n",
    "    \"self_replace_steps\": 0.4\n",
    "}\n",
    "\n",
    "# Generate images for both prompts in a single call\n",
    "outputs = pipe(\n",
    "    prompt=prompts,\n",
    "    height=512,\n",
    "    width=512,\n",
    "    num_inference_steps=50,\n",
    "    cross_attention_kwargs=cross_attention_kwargs\n",
    ")\n",
    "\n",
    "# Save each image explicitly without a loop:\n",
    "# index 0 corresponds to the source prompt, index 1 to the edited prompt\n",
    "outputs.images[0].save(\"output_image_0.png\")\n",
    "outputs.images[1].save(\"output_image_1.png\")"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
